# ------------------------------------------------------------------------------
# Fits GBMs
# Author: Cassidy Shubatt <cshubatt@gmail.com>
# To run: bsub -q big -R "rusage[mem=25000]" bash 04_fit_gbms.sh
# ------------------------------------------------------------------------------

# Seeding ----------------------------------------------------------------------
# Pin R's random-number generator state so repeated runs start identically.
set.seed(seed = 1L)

# Libraries --------------------------------------------------------------------
message("Loading libraries...")
library(yaml)
library(data.table)
library(tidyverse)
library(glue)
library(Matrix)
library(xgboost)
library(optparse)
library(here) # here() relative filepaths
library(testit) # assert functtion

# Scratch directory for this step: tuning results are read from it and the
# fitted models are written back to it.
temp <- here("code", "08_train_split_model", "temp")
# Project utility module (appears unused in this script).
u <- modules::use(here("lib", "util.R"))

# Load Data --------------------------------------------------------------------
# Read the project path config, the feature data it points to, and the
# cohort/split table stored in `temp` (presumably written by an earlier step).
message("Loading data...")
paths <- read_yaml(file = here("lib", "filepaths.yml"))
x <- readRDS(file = paths$features$train_tested)
ids <- readRDS(file = file.path(temp, "split_train_cohort.rds"))

# Subset Data ------------------------------------------------------------------
# Keep only rows not flagged by any of the three exclusion criteria. A row
# whose flag is NA is dropped as well: the combined mask is then NA or FALSE,
# and which() discards NA positions.
message("Subsetting data...")

# x and ids are subset with the same row indices below, so they must be
# row-aligned; fail loudly instead of silently pairing features with the
# wrong cohort rows.
assert(
  "features (x) and cohort ids have the same number of rows",
  nrow(x) == nrow(ids)
)

keep_pop <- !ids$excl_flag_c_int &
  !ids$excl_flag_chronic &
  !ids$excl_flag_death
keep <- which(keep_pop)

# drop = FALSE keeps x two-dimensional even if only one row/column survives.
x <- x[keep, , drop = FALSE]
ids <- ids[keep, ]

# Fitting ----------------------------------------------------------------------
# For each of the two sample splits:
#   1. restrict to that split's rows;
#   2. pick the best hyperparameters from the split's saved tuning results;
#   3. refit a single GBM on the split and save it to `temp`.
message("Fitting GBMs...")
for (i in 1:2) {
  message("Split ", i)

  # Column `sample_split{i}` in ids is used as a logical row mask.
  split_var <- glue("sample_split{i}")
  keep_split <- which(ids[[split_var]])

  # drop = FALSE keeps x_split two-dimensional even for a single row/column.
  x_split <- x[keep_split, , drop = FALSE]
  ids_split <- ids[keep_split, ]

  tuning_fp <- glue("tuning__gbm__stent_or_cabg_010_day__tested__{i}.rds")
  tuning <- readRDS(file.path(temp, tuning_fp))

  message("Identifying best parameters...")
  # The lowest logloss wins. Ties are broken in favour of larger max_depth,
  # then subsample, then colsample_bytree. top_n() keeps tied rows, so
  # slice(1) guarantees a single row (hence scalar parameters) even if ties
  # remain. The final select() presumably just moves logloss to the last
  # column (everything() re-adds it) -- confirm against installed tidyselect.
  best_params <- tuning %>%
    unnest() %>%
    top_n(1, -logloss) %>%
    top_n(1, max_depth) %>%
    top_n(1, subsample) %>%
    top_n(1, colsample_bytree) %>%
    slice(1) %>%
    select(-logloss, everything())

  assert(
    "tuning results yield exactly one best parameter set",
    nrow(best_params) == 1
  )
  assert(
    "tuning results record the tuned number of boosting rounds",
    !is.null(best_params$num_iterations)
  )

  message("Best params:")
  iwalk(best_params, ~ message(.y, " = ", .x))

  message("Fitting gbm...")
  # `num_iterations` is not an xgboost parameter; xgboost ignores unknown
  # params. The tuned number of rounds must therefore be passed as `nrounds`.
  # No early stopping: xgboost() only evaluates on the training data here,
  # and that loss keeps improving, so it cannot detect the tuned stopping point.
  gbm <- xgboost(
    data = x_split,
    label = ids_split$stent_or_cabg_010_day,
    eta = best_params$eta,
    max_depth = best_params$max_depth,
    subsample = best_params$subsample,
    colsample_bytree = best_params$colsample_bytree,
    objective = "binary:logistic",
    # One thread per CV fold present in this split -- presumably sized to the
    # cores requested for the job; confirm.
    nthread = n_distinct(ids_split$train_fold),
    nrounds = as.integer(best_params$num_iterations),
    verbose = 1
  )

  message("Saving model ", i, "...")
  # NOTE: xgb.save() writes xgboost's own model format, not an R serialized
  # object, despite the .rds extension. Read it back with xgb.load(), not
  # readRDS(). The filename is left unchanged in case downstream steps
  # expect it.
  save_fp <- file.path(
    temp,
    glue("gbm__stent_or_cabg_010_day__tested__{i}.rds")
  )
  xgb.save(gbm, save_fp)
}

# Done -------------------------------------------------------------------------
# Final status line on stderr; appendLF = TRUE (the default) ends it with a
# newline.
message("Done.", appendLF = TRUE)
